from BasicLibraries import *
import auxEffect
import utils_nc as ut
import statsmodels.api as sm
from statsmodels.discrete.discrete_model import Probit
import warnings
from statsmodels.tools.sm_exceptions import ConvergenceWarning
warnings.simplefilter('ignore', ConvergenceWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

'''
This file contains the functions that estimate the regression coefficients
'''

#=== Shared helper object from auxEffect. The functions below call its
#=== get_vars_and_dummies, make_vars and make_vars_continuous methods
BE = auxEffect.BehavioralEffect()

#==== MACRO PARAMETERS
#=== When True, the geographic dummies (g_dummies) are added to the sector
#=== dummies in ContinuousAlignment and BinaryAlignment
GEO_DUMS = True ## Control for country fixed effects
#=== Dealing with missing values in Asset4 reports characteristics
#=== If True, missing values are recoded to 2 so that they form their own class
#=== (the dummy of that class is then dropped). Otherwise, they are recoded to 0
#=== The default should be True since missing values cannot be set to zero a priori
missing_class = True

#%%
def get_data(DTA, estMetd, region_var, initiatives, envSDGs, emissionVar,
             controls_list, controls_names, analysis_type):
    """Prepare the regression inputs for the main specification.

    Parameters
    ----------
    DTA : DataFrame
        Input data. NOTE: it is modified in place, a column ``'mrg_f'``
        holding a copy of the index is added before it is handed to
        ``BE.get_vars_and_dummies``.
    estMetd
        Not used by this function; kept so the signature stays unchanged.
    region_var, initiatives, envSDGs, emissionVar, analysis_type
        Forwarded to ``BE.get_vars_and_dummies``. ``emissionVar`` is also
        the first entry of the returned variable list.
    controls_list : list
        Column names of the control variables, appended to the variable
        list.
    controls_names : list
        Display names matching ``controls_list``, appended to the list of
        display names.

    Returns
    -------
    tuple
        ``(Y, X, g_dummies, s_dummies, y_dummies, dx, vars_, vars_names,
        tmK)`` where ``Y, X, g_dummies, s_dummies, y_dummies, dx`` are the
        outputs of ``BE.make_vars``, ``vars_`` / ``vars_names`` are the
        variable list and its display names, and ``tmK`` is the last output
        of ``BE.get_vars_and_dummies``.
    """
    YearType = 'rfyear'
    DTA['mrg_f'] = DTA.index

    #=== Get sector and year dummies
    dt, g_dummies, s_dummies, y_dummies, tmK = BE.get_vars_and_dummies(
        DTA, region_var, emissionVar, initiatives, envSDGs, YearType,
        analysis_type=analysis_type)

    vars_ = [emissionVar, 'Environmental SDGs', 'Mills_Ratio'] + controls_list
    #=== Raw string: '\#' is not a valid escape sequence in a normal literal.
    #=== The value (backslash + '#') is unchanged
    vars_names = ['Emission', r'\# of initiatives', 'Mills Ratio'] + controls_names

    #=== Get variable in shape
    Y, X, g_dummies, s_dummies, y_dummies, dx = BE.make_vars(
        dt, YearType, vars_, g_dummies, s_dummies, y_dummies)
    return Y, X, g_dummies, s_dummies, y_dummies, dx, vars_, vars_names, tmK


def _continuous_report_dummies(c_indicator, column, prefix, missing_as_class):
    """Build the dummies of one report characteristic and their lags.

    Parameters
    ----------
    c_indicator : DataFrame
        Report characteristics indexed by ``'mrg'``, with missing values
        already recoded (to 2 when ``missing_as_class`` is True, else 0).
    column : str
        Column of ``c_indicator`` to encode.
    prefix : str
        Prefix of the generated column names (e.g. ``'gri'``).
    missing_as_class : bool
        True: dummies for every class are built and the one labelled 2 is
        dropped, leaving ``<prefix>_0`` and ``<prefix>_1``.
        False: the first class is dropped, leaving ``<prefix>_1``.

    Returns
    -------
    DataFrame
        Indexed by ``'mrg'``, holding the dummies followed by their lags
        (same names with an ``'a'`` suffix). The lag is the value on the
        previous row of the same ``gvkey`` after sorting by
        ``(gvkey, fyear)``; the first row of each ``gvkey`` keeps NaN lags.
    """
    if missing_as_class:
        dums = pd.get_dummies(c_indicator[column], drop_first=False)
        dums = dums.drop(columns=[2])
        names = [prefix + '_0', prefix + '_1']
    else:
        dums = pd.get_dummies(c_indicator[column], drop_first=True)
        names = [prefix + '_1']
    dums.columns = names

    #=== 'mrg' is split on '-': first part -> gvkey, second part -> fyear.
    #=== The values are assigned positionally (.values.tolist()): the Series
    #=== coming out of reset_index() has a fresh integer index, so assigning it
    #=== directly to the mrg-indexed frame would align on index and give NaN
    mrg = dums.reset_index()['mrg']
    dums['gvkey'] = mrg.apply(lambda m: m.split('-')[0]).values.tolist()
    dums['fyear'] = mrg.apply(lambda m: m.split('-')[1]).values.tolist()
    dums = dums.sort_values(by=['gvkey', 'fyear'])

    lag_names = [n + 'a' for n in names]
    dums[lag_names] = dums.groupby('gvkey').shift(1)[names]
    return dums[names + lag_names]


def ContinuousAlignment(DTA, un, target_type, estMetd, region_var, initiatives, envSDGs, emissionVar,
                        vars_continous, vars_named,
                        make_table=False, table_name='None.tex'):
    """Regress a continuous alignment measure on nested sets of controls.

    The estimation sample is built from ``DTA`` and ``un``; sector dummies
    (plus geographic dummies when ``GEO_DUMS`` is True), a ``'source'``
    indicator and sustainability-report dummies are added. One model is
    then fitted per nested control set ``vars_continous[1:2]``,
    ``vars_continous[1:3]``, ... Every model also includes
    ``'Mills_Ratio'``.

    Parameters
    ----------
    DTA : DataFrame
        Input data; after ``reset_index()`` it is merged with the ``'mrg'``
        and ``'Mills_Ratio'`` columns of ``local_data/MillsRatio.csv``.
    un : DataFrame
        Must contain the columns ``target_type`` and ``'mrg'``.
    target_type : str
        Column of ``un`` used as dependent variable.
    estMetd : {'OLS', 'RLM'}
        ``'OLS'``: least squares with covariance clustered on ``gvkey``.
        ``'RLM'``: robust linear model with the Huber T norm.
    region_var, initiatives, envSDGs, emissionVar
        Forwarded to ``BE.get_vars_and_dummies``.
    vars_continous : list
        Passed whole to ``BE.make_vars_continuous``; the elements from
        position 1 onwards are scaled and used as controls.
    vars_named : list
        Display names; ``vars_named[1:]`` labels the columns of the table.
    make_table, table_name
        Not used by this function; kept so the signature stays unchanged.

    Returns
    -------
    table : DataFrame
        One row per model ('Model 1', ...), one column per control, each
        cell being the rounded coefficient followed by the marker returned
        by ``ut.utils().significance``.
    x : DataFrame
        The estimation sample, indexed by ``'mrg'``.
    lm
        Fitted results of the last model, or None when there is no control
        set to estimate.

    Raises
    ------
    ValueError
        If ``estMetd`` is neither ``'OLS'`` nor ``'RLM'``.
    """
    if estMetd not in ('OLS', 'RLM'):
        raise ValueError("estMetd must be 'OLS' or 'RLM', got {!r}".format(estMetd))

    YearType = 'rfyear'
    MillsRatio = pd.read_csv('local_data/MillsRatio.csv')
    DTA = DTA.reset_index().merge(MillsRatio[['mrg', 'Mills_Ratio']])

    #=== Get sector and year dummies
    dt, g_dummies, s_dummies, y_dummies, tmK = BE.get_vars_and_dummies(
        DTA, region_var, emissionVar, initiatives, envSDGs, YearType,
        analysis_type='Paris')
    #=== Gap in rfyear to the previous row of the same gvkey (0 for the first
    #=== row). Rows that follow a gap larger than one are dropped
    dt['difs'] = dt[['gvkey', 'rfyear']].groupby('gvkey').diff().replace([np.nan], [0])
    dt = dt[dt.difs <= 1]

    #=== Get variable in shape
    X, g_dummies, s_dummies, y_dummies, dx = BE.make_vars_continuous(
        dt, YearType, vars_continous, g_dummies, s_dummies, y_dummies)

    #=== Merge with alignment
    x = un[[target_type, 'mrg']].merge(X, on='mrg')
    #=== Trim: keep rows whose absolute value lies below the 99th percentile
    for col in ['firm_size_H', 'investment_log_H', 'Tangibility_H']:
        x = x[x[col].abs() < x[col].quantile(0.99)]
    x = x[x[target_type].abs() < x[target_type].quantile(0.99)]

    #=== Scale the variables (`scale` comes from the BasicLibraries star import)
    scaled_cols = [target_type] + vars_continous[1:]
    x[scaled_cols] = x[scaled_cols].apply(scale)
    x = x.dropna()

    listOfDummies = list(s_dummies.columns)
    if GEO_DUMS:
        listOfDummies = listOfDummies + list(g_dummies.columns)

    #=== Nested control sets: [v1], [v1, v2], ..., [v1, ..., v_last]
    controls = [vars_continous[1:stop] for stop in range(2, len(vars_continous) + 1)]

    ### Get Source indicator
    p_indicator = pd.read_csv('local_data/pa_methods.csv')
    x = x.merge(p_indicator[['source', 'mrg']])
    listOfDummies = listOfDummies + ['source']

    ### Get Report characteristics
    c_indicator = pd.read_csv('local_data/reports_characteristics.csv')
    c_indicator = c_indicator.dropna(subset=['fyear']).reset_index(drop=True)
    c_indicator = c_indicator.groupby('mrg').last().reset_index()

    reports_characteristics = ['GRI_Report_Guidelines', 'CSR_Sustainability_Report_Global_Activities', 'CSR_Sustainability_External_Audit']
    #=== Missing values become their own class (2) or are treated as zero
    missing_code = 2 if missing_class else 0
    c_indicator = c_indicator[reports_characteristics + ['mrg']].replace([np.nan], [missing_code]).set_index('mrg')

    gri = _continuous_report_dummies(c_indicator, 'GRI_Report_Guidelines', 'gri', missing_class)
    aud = _continuous_report_dummies(c_indicator, 'CSR_Sustainability_External_Audit', 'aud', missing_class)
    x = x.merge(gri.reset_index())
    x = x.merge(aud.reset_index())

    if missing_class:
        #=== The lagged dummies are merged into x but not used as regressors
        listOfDummies = listOfDummies + ['gri_0', 'gri_1', 'aud_0', 'aud_1']
    else:
        #=== NOTE(review): the lagged dummies are NaN on the first row of each
        #=== gvkey and those rows are not dropped here (BinaryAlignment drops
        #=== them) - confirm how the regression below should treat them
        listOfDummies = listOfDummies + ['gri_1', 'aud_1', 'gri_1a', 'aud_1a']

    x = x.set_index('mrg')

    #=== Remove the dummies that are switched on for at most `trs` observations
    trs = 5
    counts = x[listOfDummies].sum()
    too_rare = set(counts[counts <= trs].dropna().index)
    listOfDummies = [d for d in listOfDummies if d not in too_rare]

    formatter = ut.utils()
    lm = None
    table = []
    for ctrl in controls:
        regressors = x[['Mills_Ratio'] + ctrl + listOfDummies]
        if estMetd == 'OLS':
            #=== NOTE(review): `cov='HC2'` is kept as in the original call; the
            #=== covariance type is set by cov_type='cluster' - confirm that
            #=== the extra keyword has no effect
            lm = sm.OLS(x[target_type], regressors).fit(
                cov='HC2', cov_type='cluster', cov_kwds={'groups': x['gvkey']})
        else:
            lm = sm.RLM(x[target_type], regressors, M=sm.robust.norms.HuberT()).fit()

        res = pd.concat((lm.params.round(2), lm.pvalues.round(3)), axis=1)
        res.columns = ['parameters', 'p-value']
        table.append([str(res.loc[k].parameters) + formatter.significance(res.loc[k]['p-value']) for k in ctrl])

    table = pd.DataFrame(table, columns=vars_named[1:],
                         index=['Model ' + str(n + 1) for n in range(len(table))])

    return table, x, lm

def _binary_report_dummies(c_indicator, column, prefix, missing_as_class):
    """Build the dummies of one report characteristic and their lags.

    Parameters
    ----------
    c_indicator : DataFrame
        Report characteristics indexed by ``'mrg'``, with missing values
        already recoded (to 2 when ``missing_as_class`` is True, else 0).
    column : str
        Column of ``c_indicator`` to encode.
    prefix : str
        Prefix of the generated column names (e.g. ``'gri'``).
    missing_as_class : bool
        True: dummies for every class are built and the one labelled 2 is
        dropped, leaving ``<prefix>_0`` and ``<prefix>_1``.
        False: the first class is dropped, leaving ``<prefix>_1``.

    Returns
    -------
    DataFrame
        Indexed by ``'mrg'``, holding the dummies followed by their lags
        (same names with an ``'a'`` suffix). The lag is the value on the
        previous row of the same ``gvkey`` after sorting by
        ``(gvkey, fyear)``. Rows containing a missing value (the first row
        of each ``gvkey``, which has no lag) are dropped.
    """
    if missing_as_class:
        dums = pd.get_dummies(c_indicator[column], drop_first=False)
        dums = dums.drop(columns=[2])
        names = [prefix + '_0', prefix + '_1']
    else:
        dums = pd.get_dummies(c_indicator[column], drop_first=True)
        names = [prefix + '_1']
    dums.columns = names

    #=== 'mrg' is split on '-': first part -> gvkey, second part -> fyear.
    #=== Assigned positionally because reset_index() yields a fresh integer
    #=== index that would not align with the mrg index of `dums`
    mrg = dums.reset_index()['mrg']
    dums['gvkey'] = mrg.apply(lambda m: m.split('-')[0]).values.tolist()
    dums['fyear'] = mrg.apply(lambda m: m.split('-')[1]).values.tolist()
    dums = dums.sort_values(by=['gvkey', 'fyear'])

    lag_names = [n + 'a' for n in names]
    dums[lag_names] = dums.groupby('gvkey').shift(1)[names]
    dums = dums.dropna()
    return dums[names + lag_names]


def BinaryAlignment(DTA, un, target_type, estMetd, region_var, initiatives, envSDGs, emissionVar,
                    vars_continous, vars_named,
                    make_table=False, table_name='None.tex'):
    """Fit Probit models of the recoded ``'signed'`` indicator on nested controls.

    The estimation sample is built from ``DTA`` and ``un``; sector dummies
    (plus geographic dummies when ``GEO_DUMS`` is True), a ``'source'``
    indicator and current and lagged sustainability-report dummies are
    added. One Probit is then fitted per nested control set
    ``vars_continous[1:2]``, ``vars_continous[1:3]``, ...

    Parameters
    ----------
    DTA : DataFrame
        Input data; after ``reset_index()`` it is merged with the ``'mrg'``
        and ``'Mills_Ratio'`` columns of ``local_data/MillsRatio.csv``.
    un : DataFrame
        Must contain the columns ``'signed'`` and ``'mrg'``. It is copied,
        not modified.
    target_type, estMetd, make_table, table_name
        Not used by this function (a Probit on ``'signed'`` is always
        fitted); kept so the signature stays unchanged.
    region_var, initiatives, envSDGs, emissionVar
        Forwarded to ``BE.get_vars_and_dummies``.
    vars_continous : list
        Passed whole to ``BE.make_vars_continuous``; the elements from
        position 1 onwards are scaled and used as controls.
    vars_named : list
        Display names; ``vars_named[1:]`` labels the columns of the table.

    Returns
    -------
    table : DataFrame
        One row per model ('Model 1', ...), one column per control, each
        cell being the rounded coefficient followed by the marker returned
        by ``ut.utils().significance``.
    x : DataFrame
        The estimation sample, indexed by ``'mrg'``.
    lm
        Fitted results of the last model, or None when there is no control
        set to estimate.
    """
    YearType = 'rfyear'
    MillsRatio = pd.read_csv('local_data/MillsRatio.csv')
    DTA = DTA.reset_index().merge(MillsRatio[['mrg', 'Mills_Ratio']])

    #=== Get sector and year dummies
    dt, g_dummies, s_dummies, y_dummies, tmK = BE.get_vars_and_dummies(
        DTA, region_var, emissionVar, initiatives, envSDGs, YearType,
        analysis_type='Paris')
    #=== Gap in rfyear to the previous row of the same gvkey (0 for the first
    #=== row). Rows that follow a gap larger than one are dropped
    dt['difs'] = dt[['gvkey', 'rfyear']].groupby('gvkey').diff().replace([np.nan], [0])
    dt = dt[dt.difs <= 1]

    #=== Get variable in shape
    X, g_dummies, s_dummies, y_dummies, dx = BE.make_vars_continuous(
        dt, YearType, vars_continous, g_dummies, s_dummies, y_dummies)

    #=== Merge with alignment; 'signed' is recoded 1 -> 0 and -1 -> 1
    un_ = un.copy()
    un_['signed'] = un_['signed'].replace([1, -1], [0, 1])
    x = un_[['signed', 'mrg']].merge(X, on='mrg')
    #=== Trim: keep rows whose absolute value lies below the 99th percentile
    for col in ['firm_size_H', 'investment_log_H', 'Tangibility_H']:
        x = x[x[col].abs() < x[col].quantile(0.99)]

    #=== Scale the variables (`scale` comes from the BasicLibraries star import)
    x[vars_continous[1:]] = x[vars_continous[1:]].apply(scale)
    x = x.dropna()

    listOfDummies = list(s_dummies.columns)
    if GEO_DUMS:
        listOfDummies = listOfDummies + list(g_dummies.columns)

    #=== Nested control sets: [v1], [v1, v2], ..., [v1, ..., v_last]
    controls = [vars_continous[1:stop] for stop in range(2, len(vars_continous) + 1)]

    ### Get Source indicator
    p_indicator = pd.read_csv('local_data/pa_methods.csv')
    x = x.merge(p_indicator[['source', 'mrg']])
    listOfDummies = listOfDummies + ['source']

    #=== Replace the Mills ratio carried through X by the values from the CSV.
    #=== The frame loaded at the top is reused instead of reading the file again
    x = x.drop(columns=['Mills_Ratio'])
    x = x.merge(MillsRatio[['mrg', 'Mills_Ratio']])

    ### Get Report characteristics
    c_indicator = pd.read_csv('local_data/reports_characteristics.csv')
    c_indicator = c_indicator.dropna(subset=['fyear']).reset_index(drop=True)
    c_indicator = c_indicator.groupby('mrg').last().reset_index()

    reports_characteristics = ['GRI_Report_Guidelines', 'CSR_Sustainability_Report_Global_Activities', 'CSR_Sustainability_External_Audit']
    #=== Missing values become their own class (2) or are treated as zero
    missing_code = 2 if missing_class else 0
    c_indicator = c_indicator[reports_characteristics + ['mrg']].replace([np.nan], [missing_code]).set_index('mrg')

    gri = _binary_report_dummies(c_indicator, 'GRI_Report_Guidelines', 'gri', missing_class)
    aud = _binary_report_dummies(c_indicator, 'CSR_Sustainability_External_Audit', 'aud', missing_class)
    x = x.merge(gri.reset_index())
    x = x.merge(aud.reset_index())

    if missing_class:
        listOfDummies = listOfDummies + ['gri_0', 'gri_1', 'gri_0a', 'gri_1a', 'aud_0', 'aud_1', 'aud_0a', 'aud_1a']
    else:
        listOfDummies = listOfDummies + ['gri_1', 'aud_1', 'gri_1a', 'aud_1a']

    x = x.set_index('mrg')

    #=== Remove the dummies that are switched on for at most `trs` observations
    trs = 5
    counts = x[listOfDummies].sum()
    too_rare = set(counts[counts <= trs].dropna().index)
    listOfDummies = [d for d in listOfDummies if d not in too_rare]

    formatter = ut.utils()
    lm = None
    table = []
    for ctrl in controls:
        lm = Probit(x['signed'], x[ctrl + listOfDummies]).fit(disp=0)
        res = pd.concat((lm.params.round(2), lm.pvalues.round(3)), axis=1)
        res.columns = ['parameters', 'p-value']
        table.append([str(res.loc[k].parameters) + formatter.significance(res.loc[k]['p-value']) for k in ctrl])

    table = pd.DataFrame(table, columns=vars_named[1:],
                         index=['Model ' + str(n + 1) for n in range(len(table))])

    return table, x, lm


        
        
        
        
        
        